/* Copyright 2012 Tobias Marschall
 * 
 * This file is part of CLEVER.
 * 
 * CLEVER is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * CLEVER is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with CLEVER.  If not, see <http://www.gnu.org/licenses/>.
 */

#include <iostream>
#include <fstream>
#include <sstream>
#include <vector>
#include <limits>
#include <cassert>
#include <iomanip>
#include <ctime>

#include <boost/iostreams/device/file.hpp>
#include <boost/iostreams/filtering_stream.hpp>
#include <boost/iostreams/filter/gzip.hpp>
#include <boost/program_options.hpp>

#include "SortedBamReader.h"
#include "GroupWiseBamReader.h"
#include "ThreadPool.h"
#include "ShortDnaSequence.h"
#include "ReadGroups.h"
#include "VersionInfo.h"

using namespace std;
namespace po = boost::program_options;
namespace io = boost::iostreams;

/* Prints the command-line synopsis, a short description and the option list
 * to stderr and then terminates the program with exit status 1. */
void usage(const char* name, const po::options_description& options_desc) {
	cerr << "Usage: " << name << " [options] <output.1.fastq(.gz)> <output.2.fastq(.gz)>" << endl;
	cerr << endl;
	cerr << "Reads a sorted BAM file from standard input (also parsing the BWA-specific" << endl;
	cerr << "X0/X1 tags) and extracts all bad or otherwise suspicious alignments." << endl;
	cerr << endl;
	cerr << "If both given filenames end in \".gz\", they are automatically gzipped." << endl;
	cerr << endl;
	cerr << options_desc << endl;
	exit(1);
}

typedef enum { OFF = 0, BY_READ_GROUP = 1, BY_SAMPLE = 2 } separation_t;

/* Parses a separation mode for option -D from `in`.
 * Accepted tokens: "off", "readgroup" and "sample".
 * Throws std::runtime_error for any other token. */
std::istream& operator>>(std::istream& in, separation_t& s) {
	std::string token;
	in >> token;
	if (token.compare("off") == 0) {
		s = OFF;
	} else if (token.compare("readgroup") == 0) {
		s = BY_READ_GROUP;
	} else if (token.compare("sample") == 0) {
		s = BY_SAMPLE;
	} else {
		throw std::runtime_error("Invalid argument to option -D");
	}
	return in;
}

/* Inverse of the operator>> above: writes the keyword that selects `s`.
 * boost::program_options renders default_value(OFF) of option -D through
 * operator<<; without this overload the enum would be streamed as its integer
 * value and the help text would show "(=0)" instead of "(=off)".
 * Values outside the enum are written as integers. */
std::ostream& operator<<(std::ostream& out, separation_t s) {
	switch (s) {
	case OFF:
		return out << "off";
	case BY_READ_GROUP:
		return out << "readgroup";
	case BY_SAMPLE:
		return out << "sample";
	}
	return out << static_cast<int>(s);
}

/* Returns true iff the alignment's CIGAR contains an operation other than a
 * plain (mis)match ('M', '=' or 'X'). Despite the name, this flags not only
 * insertions and deletions but every other CIGAR operation as well
 * (e.g. soft/hard clips, skipped regions, padding). */
bool contains_indels(const BamTools::BamAlignment& alignment) {
	const vector<BamTools::CigarOp>& cigar = alignment.CigarData;
	for (size_t k = 0; k < cigar.size(); ++k) {
		const char op = cigar[k].Type;
		const bool is_match_like = (op == 'M') || (op == '=') || (op == 'X');
		if (!is_match_like) {
			return true;
		}
	}
	return false;
}

/* Returns true iff the alignment is a mapped primary alignment with non-zero
 * mapping quality. If BWA's tags are present, it must also have at most one
 * best hit (X0 <= 1) and no suboptimal hits (X1 == 0). Missing X0/X1 tags do
 * not disqualify the alignment. */
bool is_uniquely_mapped(const BamTools::BamAlignment& alignment) {
	const bool basic_ok = alignment.IsMapped() && (alignment.MapQuality != 0) && alignment.IsPrimaryAlignment();
	if (!basic_ok) return false;
	uint32_t best_hits = std::numeric_limits<uint32_t>::max();
	if (alignment.GetTag("X0", best_hits) && (best_hits > 1)) return false;
	uint32_t suboptimal_hits = std::numeric_limits<uint32_t>::max();
	if (alignment.GetTag("X1", suboptimal_hits) && (suboptimal_hits > 0)) return false;
	return true;
}

/* Extraction settings filled in from the command line by main().
 * Work packages only hold a const reference to them. */
struct parameters_t {
	bool all;                   // -a: extract every pair, not only suspicious ones
	int min_insert_size;        // -m: -1 = unset (the insert size check needs both bounds)
	int max_insert_size;        // -M: -1 = unset
	int min_mapq;               // -Q: pairs with a lower mapping quality are extracted
	int max_edit_distance;      // -e: pairs with a larger NM value are extracted
	bool write_split;           // set by main() when a split-read file (-S) is given
	int split_length;           // -l: length of the prefix/suffix written for split reads
	bool read_groups_in_names;  // -r: prefix read names with their read group
	separation_t separation;    // -D: how output is distributed over files
	ReadGroups* read_groups;    // 0 unless output is distributed by read group/sample

	parameters_t()
		: all(false),
		  min_insert_size(-1),
		  max_insert_size(-1),
		  min_mapq(30),
		  max_edit_distance(3),
		  write_split(false),
		  split_length(35),
		  read_groups_in_names(false),
		  separation(OFF),
		  read_groups(0)
	{}
};

typedef struct output_streams_t {
	ofstream* fastq_1_ofstream;
	io::filtering_ostream* fastq_1_os;
	ofstream* fastq_2_ofstream;
	io::filtering_ostream* fastq_2_os;
	ofstream* split_ofstream;
	io::filtering_ostream* split_os;

	void open_streams(ofstream** ofs, io::filtering_ostream** os, const string& filename, bool zip) {
		*ofs = new ofstream(filename.c_str());
		if ((*ofs)->fail()) {
			ostringstream oss;
			oss << "Error opening file \"" + filename + "\"." << endl;
			throw runtime_error(oss.str());
		}
		*os = new io::filtering_ostream;
		if (zip) (*os)->push(io::gzip_compressor());
		(*os)->push(**ofs);
	}
	
	void close() {
		if (fastq_1_os != 0) delete fastq_1_os;
		if (fastq_1_ofstream != 0) delete fastq_1_ofstream;
		if (fastq_2_os != 0) delete fastq_2_os;
		if (fastq_2_ofstream != 0) delete fastq_2_ofstream;
		if (split_os != 0) delete split_os;
		if (split_ofstream != 0) delete split_ofstream;
	}
	
	output_streams_t(const string& fastq_1_filename, const string& fastq_2_filename, const string& split_filename, bool zip) : fastq_1_ofstream(0), fastq_1_os(0), fastq_2_ofstream(0), fastq_2_os(0), split_ofstream(0), split_os(0) {
		open_streams(&fastq_1_ofstream, &fastq_1_os, fastq_1_filename, zip);
		open_streams(&fastq_2_ofstream, &fastq_2_os, fastq_2_filename, zip);
		if (split_filename.size() > 0) {
			open_streams(&split_ofstream, &split_os, split_filename, zip);
		}
	}
} output_streams_t;

/* One unit of work for the thread pool. It decides whether a read pair is to
 * be extracted and, if so, renders the pair's FASTQ records into in-memory
 * buffers; output_writer_t later copies these buffers to the output files.
 *
 * Takes ownership of alignments1 and alignments2 (the alignments of the first
 * and of the second read of the pair). The vectors and all alignments in them
 * are deleted on destruction. */
typedef struct work_package_t {
	vector<BamTools::BamAlignment*>* alignments1;
	vector<BamTools::BamAlignment*>* alignments2;
	ostringstream output1;        // FASTQ record of read 1 (filled if extracted)
	ostringstream output2;        // FASTQ record of read 2 (filled if extracted)
	ostringstream split_output;   // prefix/suffix records of both reads (if -S is used)
	const parameters_t& parameters;
	bool use_read;                // set by run(): pair is to be written out
	bool discordant_insert_size;  // set by run(): pair selected solely because of its insert size
	int output_index;             // output partition; 0 unless output is distributed by read group/sample
	
	work_package_t(vector<BamTools::BamAlignment*>* alignments1, vector<BamTools::BamAlignment*>* alignments2, const parameters_t& parameters) : alignments1(alignments1), alignments2(alignments2), parameters(parameters), use_read(false), discordant_insert_size(false), output_index(0) {}

	~work_package_t() {
		free(alignments1);
		free(alignments2);
	}

	/* Deletes every alignment in `a` and then the vector itself. */
	void free(vector<BamTools::BamAlignment*>* a) {
		assert(a != 0);
		for (size_t i=0; i<a->size(); ++i) {
			assert(a->at(i) != 0);
			delete a->at(i);
		}
		delete a;
	}

	/* Appends the FASTQ record of `aln` to `os`. If split reads are enabled,
	 * also appends its left and right ends (split_length bases each) to
	 * split_output, with read names suffixed "_<nr>L" / "_<nr>R".
	 * Throws std::runtime_error if the read is shorter than split_length. */
	void write_aln(ostringstream& os, const BamTools::BamAlignment* aln, int nr) {
		ShortDnaSequence seq(aln->QueryBases, aln->Qualities);
		// BAM stores reverse-strand alignments reverse-complemented; undo that
		// to recover the read as sequenced.
		if (aln->IsMapped() && aln->IsReverseStrand()) {
			seq = seq.reverseComplement();
		}
		string rg = "";
		if (parameters.read_groups_in_names) {
			if (!aln->GetTag("RG", rg)) {
				rg = "-1";
			}
			os << "@" << rg << "_" << aln->Name << endl;
		} else {
			os << "@" << aln->Name << endl;
		}
		const string bases = seq.toString();
		const string qualities = seq.qualityString();
		os << bases << endl;
		os << "+" << endl;
		os << qualities << endl;
		if (parameters.write_split) {
			if ((int)seq.size() < parameters.split_length) {
				ostringstream oss;
				oss << "Error: read \"" << aln->Name << "\" too short (" << seq.size() << "bp)." << endl;
				throw std::runtime_error(oss.str());
			}
			const size_t length = parameters.split_length;
			const size_t right_start = seq.size() - length;
			// Write left end of read
			if (parameters.read_groups_in_names) {
				split_output << "@" << rg << "_" << aln->Name << "_" << nr << "L" << endl;
			} else {
				split_output << "@" << aln->Name << "_" << nr << "L" << endl;
			}
			split_output << bases.substr(0, length) << endl;
			split_output << "+" << endl;
			split_output << qualities.substr(0, length) << endl;
			// Write right end of read
			if (parameters.read_groups_in_names) {
				split_output << "@" << rg << "_" << aln->Name << "_" << nr << "R" << endl;
			} else {
				split_output << "@" << aln->Name << "_" << nr << "R" << endl;
			}
			split_output << bases.substr(right_start, length) << endl;
			split_output << "+" << endl;
			split_output << qualities.substr(right_start, length) << endl;
		}
	}

	/* Distance between the two alignments, computed as
	 * (start of the rightmost) - GetEndPosition(leftmost) - 1. The result is
	 * negative for overlapping alignments.
	 * NOTE(review): BamTools' GetEndPosition() returns a half-open end by
	 * default, in which case this is one less than the number of bases
	 * strictly between the reads. Confirm that this matches the convention
	 * used to obtain the -m/-M bounds before changing it. */
	int get_insert_size(const BamTools::BamAlignment* aln1, const BamTools::BamAlignment* aln2) {
		const BamTools::BamAlignment* left = (aln1->Position < aln2->Position)?aln1:aln2;
		const BamTools::BamAlignment* right = (aln1->Position < aln2->Position)?aln2:aln1;
		return right->Position - left->GetEndPosition() - 1;
	}

	/* Value of the NM tag, or -1 if the tag is absent. Because of the -1, a
	 * read without NM never exceeds max_edit_distance. */
	int get_edit_distance(const BamTools::BamAlignment* aln) {
		uint32_t nm = 0;
		if (aln->GetTag("NM", nm)) {
			return nm;
		}
		return -1;
	}

	/* Sets use_read / discordant_insert_size / output_index and, for selected
	 * pairs, renders output1, output2 (and split_output). Only the first
	 * alignment of each read is inspected. A read with more than one alignment
	 * always leads to extraction. Pairs in which a read has no alignment at
	 * all are never extracted. */
	void run() {
		assert(alignments1 != 0);
		assert(alignments2 != 0);
		if ((alignments1->size() == 0) || (alignments2->size() == 0)) {
			return;
		}
		BamTools::BamAlignment* aln1 = alignments1->at(0);
		BamTools::BamAlignment* aln2 = alignments2->at(0);
		assert(aln1 != 0);
		assert(aln2 != 0);
		if (parameters.all) {
			use_read = true;
		} else {
			if (!aln1->IsMapped()) use_read = true;
			if (!aln2->IsMapped()) use_read = true;
			// Multiple alignments reported for either read.
			if (alignments1->size() > 1) use_read = true;
			if (alignments2->size() > 1) use_read = true;
			if (!is_uniquely_mapped(*aln1)) use_read = true;
			if (!is_uniquely_mapped(*aln2)) use_read = true;
			if (contains_indels(*aln1)) use_read = true;
			if (contains_indels(*aln2)) use_read = true;
			if (aln1->MapQuality < parameters.min_mapq) use_read = true;
			if (aln2->MapQuality < parameters.min_mapq) use_read = true;
			if (get_edit_distance(aln1) > parameters.max_edit_distance) use_read = true;
			if (get_edit_distance(aln2) > parameters.max_edit_distance) use_read = true;
		}
		// The insert size is only checked for otherwise good pairs with both bounds set,
		// so discordant_insert_size means "extracted solely for this reason".
		if (!use_read && (parameters.min_insert_size != -1) && (parameters.max_insert_size != -1)) {
			int insert_size = get_insert_size(aln1, aln2);
			if (insert_size < parameters.min_insert_size) {
				use_read = true;
				discordant_insert_size = true;
			}
			if (insert_size > parameters.max_insert_size) {
				use_read = true;
				discordant_insert_size = true;
			}
		}
		if (use_read) {
			if (parameters.read_groups != 0) {
				output_index = parameters.read_groups->getIndex(*aln1);
				if (output_index < 0) {
					ostringstream oss;
					oss << "Unknown read group for read " << aln1->Name;
					throw runtime_error(oss.str());
				}
			}
			write_aln(output1, aln1, 1);
			write_aln(output2, aln2, 2);
		}
	}
} work_package_t;
 
typedef struct output_writer_t {
	vector<output_streams_t>& output_streams;
	long long total_count;
	long long extracted_count;
	long long discordant_insert_size_count;

	void write(auto_ptr<work_package_t> work) {
		assert(work.get() != 0);
		assert(work->output_index >= 0);
		assert(work->output_index < output_streams.size());
		output_streams_t& out = output_streams[work->output_index];
		total_count += 1;
		if (work->use_read) {
			(*out.fastq_1_os) << work->output1.str();
			(*out.fastq_2_os) << work->output2.str();
			if (out.split_os != 0) {
				(*out.split_os) << work->split_output.str();
			}
			extracted_count += 1;
			if (work->discordant_insert_size) {
				discordant_insert_size_count += 1;
			}
		}
	}
	output_writer_t(vector<output_streams_t>& output_streams) : output_streams(output_streams), total_count(0), extracted_count(0), discordant_insert_size_count(0)  {}
} output_writer_t;

/* Entry point. Parses the options and the two trailing output filenames,
 * reads BAM records from standard input and hands each read pair to the
 * thread pool as a work_package_t. Finally prints summary statistics to
 * stderr. Exits with status 1 on errors. */
int main(int argc, char* argv[]) {
	VersionInfo::checkAndPrintVersion("extract-bad-reads", cerr);
	string commandline = VersionInfo::commandline(argc, argv);

	// PARAMETERS
	int max_span;
	int threads;
	string split_filename = "";
	parameters_t parameters;
	bool zip_output = false;
	bool unsorted = false;
	bool use_hard_clipped = false;
	
	po::options_description options_desc("Allowed options");
	options_desc.add_options()
		("all,a", po::value<bool>(&parameters.all)->zero_tokens(), "Extract all reads instead of only \"bad\" ones.")
		("unsorted,u", po::value<bool>(&unsorted)->zero_tokens(), "Input is not sorted by position but grouped by readname, i.e., all alignments of a read pair are in subsequent lines.")
		("max_span,s", po::value<int>(&max_span)->default_value(50000), "Maximal internal segment. Read pairs with larger internal segment will be ignored.")
		("threads,T", po::value<int>(&threads)->default_value(0), "Number of threads (default: 0 = strictly single-threaded).")
		("min_insert_size,m", po::value<int>(&parameters.min_insert_size)->default_value(-1), "Minimum internal segment size (excluding reads) for a pair to be considered good (and thus not be extracted).")
		("max_insert_size,M", po::value<int>(&parameters.max_insert_size)->default_value(-1), "Maximum internal segment size (excluding reads) for a pair to be considered good (and thus not be extracted).")
		("min_mapq,Q", po::value<int>(&parameters.min_mapq)->default_value(30), "Minimum mapping quality for good reads (which are not to be extracted).")
		("max_edit_distance,e", po::value<int>(&parameters.max_edit_distance)->default_value(3), "Maximum allowed edit distance, reads with larger distance will be extracted.")
		("split_file,S", po::value<string>(&split_filename)->default_value(""), "Filename to write split reads to (FASTQ format, gzipped if the two main output files are).")
		("split_length,l", po::value<int>(&parameters.split_length)->default_value(35), "Length of prefix/suffix to be extracted (if option -S is used).")
		("read_groups,r", po::value<bool>(&parameters.read_groups_in_names)->zero_tokens(), "Encode read groups in read_names (as \"<readgroup>_<name>\").")
		("distribute_output,D", po::value<separation_t>(&parameters.separation)->default_value(OFF), "Distribute output over multiple files according to [off|readgroup|sample].")
		("use_hard_clipped,H", po::value<bool>(&use_hard_clipped)->zero_tokens(), "Also use hard clipped reads (default: ignore hard clipped reads).")
	;
	
	if (argc<3) {
		usage(argv[0], options_desc);
	}
	// The two output filenames are the last arguments; hide them from the option parser.
	string fastq_1_filename(argv[argc-2]);
	string fastq_2_filename(argv[argc-1]);
	argc -= 2;

	po::variables_map options;
	try {
		po::store(po::parse_command_line(argc, argv, options_desc), options);
		po::notify(options);
	} catch(exception& e) {
		cerr << "error: " << e.what() << "\n";
		return 1;
	}
	cerr << "Commandline: " << commandline << endl;

	// Compress only if both output names end in ".gz"; the split-read file follows suit.
	if ((fastq_1_filename.substr(max(0,(int)fastq_1_filename.size()-3),3).compare(".gz") == 0) && (fastq_2_filename.substr(max(0,(int)fastq_2_filename.size()-3),3).compare(".gz") == 0)) {
		zip_output = true;
	}

	clock_t clock_start = clock();
	if (split_filename.size() > 0) {
		parameters.write_split = true;
		// Non-positive lengths would yield empty or out-of-range prefixes/suffixes.
		if (parameters.split_length < 1) {
			cerr << "error: split length (option -l) must be at least 1." << endl;
			return 1;
		}
	}
	BamReader* bam_reader = 0;
	typedef ThreadPool<work_package_t,output_writer_t> thread_pool_t;
	output_writer_t* output_writer = 0;
	vector<output_streams_t> output_streams;
	try {
		if (unsorted) {
			bam_reader = new GroupWiseBamReader("/dev/stdin", true, true, false, !use_hard_clipped);
		} else {
			bam_reader = new SortedBamReader("/dev/stdin", true, max_span, false, !use_hard_clipped);
		}
		bam_reader->enableProgressMessages(cerr, 200000);
		if (parameters.separation == OFF) {
			output_streams.push_back(output_streams_t(fastq_1_filename,fastq_2_filename,split_filename,zip_output));
		} else {
			if (!bam_reader->getHeader().HasReadGroups()) {
				cerr << "No read group information found in BAM input." << endl;
				return 1;
			}
			parameters.read_groups = new ReadGroups(bam_reader->getHeader().ReadGroups, parameters.separation == BY_SAMPLE);
			// One set of output files per read group (or sample), suffixed with its name.
			for (size_t i=0; i<parameters.read_groups->size(); ++i) {
				ostringstream f1;
				f1 << fastq_1_filename << "." << parameters.read_groups->getName(i);
				ostringstream f2;
				f2 << fastq_2_filename << "." << parameters.read_groups->getName(i);
				if (split_filename.size() == 0) {
					output_streams.push_back(output_streams_t(f1.str(),f2.str(),"",zip_output));
				} else {
					ostringstream split;
					split << split_filename << "." << parameters.read_groups->getName(i);
					output_streams.push_back(output_streams_t(f1.str(),f2.str(),split.str(),zip_output));
				}
			}
		}
		output_writer = new output_writer_t(output_streams);
		thread_pool_t thread_pool(threads, 1000, threads, *output_writer);
		while ( bam_reader->hasNext() ) {
			bam_reader->advance();
			// if (bam_reader->isFirstUnmapped() || bam_reader->isSecondUnmapped()) continue;
			vector<BamTools::BamAlignment*>* alignments1 = bam_reader->releaseAlignmentsFirst().release();
			vector<BamTools::BamAlignment*>* alignments2 = bam_reader->releaseAlignmentsSecond().release();
			thread_pool.addTask(auto_ptr<work_package_t>(new work_package_t(alignments1, alignments2, parameters)));
		}
	} catch(exception& e) {
		cerr << "Error: " << e.what() << "\n";
		return 1;
	}
	if (bam_reader->getNonPairedCount() > 0) {
		cerr << "Skipped " << bam_reader->getNonPairedCount() << " reads with no matching mate alignment." << endl;
	}
	if (bam_reader->getSkippedDuplicates() > 0) {
		cerr << "Skipped " << bam_reader->getSkippedDuplicates() << " duplicate reads." << endl;
	}
	cerr << "Extracted " << output_writer->extracted_count << " out of " << output_writer->total_count << " total reads." << endl;
	cerr << "Extracted " << output_writer->discordant_insert_size_count << " (solely) because internal segment size was discordant." << endl;
	for (size_t i=0; i<output_streams.size(); ++i) {
		output_streams[i].close();
	}
	delete output_writer;
	// Allocated above only when output is distributed; 0 (and thus a no-op) otherwise.
	delete parameters.read_groups;
	parameters.read_groups = 0;
	if (bam_reader != 0) delete bam_reader;
	double cpu_time = (double)(clock() - clock_start) / CLOCKS_PER_SEC;
	cerr << "Total CPU time: " << cpu_time << endl;
	return 0;
}
